2024.4.14 4次のルンゲクッタ法【torch】
code:rk4_torch.py
import torch as pt
import matplotlib.pyplot as plt
import numpy as np
def func_xd(t, x, mat=None):
    """Return the time derivative of a linear time-invariant system, ``mat @ x``.

    Args:
        t: Current time. Not used (the system is time-invariant), but kept so the
            signature matches the ``f(t, x)`` convention used by the solvers.
        x: State tensor; must be shape-compatible with ``mat @ x``.
        mat: Optional system matrix. When omitted, the module-level ``A`` is used,
            which preserves the original behaviour.

    Returns:
        The tensor ``mat @ x`` (dx/dt).
    """
    # Conditional expression is lazy: the global ``A`` is only looked up when
    # no explicit matrix is supplied.
    system = A if mat is None else mat
    return system @ x
def solve_rk4(t_span, x, h, f=None):
    """Integrate dx/dt = f(t, x) with the classical 4th-order Runge-Kutta method.

    Args:
        t_span: ``(t_start, t_end)`` pair giving the integration interval.
        x: Initial state tensor.
        h: Step size; must be positive.
        f: Right-hand side ``f(t, x)``. Defaults to the module-level ``func_xd``.

    Returns:
        ``(hist_t, hist_x)``:
        - ``hist_t`` is a list of sample times. It starts at ``t_start`` and
          holds one entry per completed step. It ends at ``t_end`` when the
          interval is a whole number of steps.
        - ``hist_x`` is a NumPy array whose i-th row is the flattened state at
          ``hist_t[i]``.

    Raises:
        ValueError: If ``h`` is not positive.
    """
    if h <= 0:
        raise ValueError(f"step size h must be positive, got {h}")
    rhs = func_xd if f is None else f
    t_start, t_end = t_span
    # Count the steps up front (with a small tolerance for float round-off)
    # and compute each time as t_start + i*h. This avoids the drift that
    # comes from repeatedly adding h to t.
    n_steps = max(0, int((t_end - t_start) / h + 1e-9))

    hist_t = [t_start]
    # Detach + clone so the history neither tracks gradients nor aliases x.
    hist_x = [x.detach().flatten().clone()]
    t = t_start
    for i in range(1, n_steps + 1):
        k1 = rhs(t, x)
        k2 = rhs(t + h / 2, x + k1 * h / 2)
        k3 = rhs(t + h / 2, x + k2 * h / 2)
        k4 = rhs(t + h, x + k3 * h)
        x = x + h * (k1 + 2 * k2 + 2 * k3 + k4) / 6
        t = t_start + i * h
        hist_t.append(t)
        hist_x.append(x.detach().flatten().clone())
    # torch.stack + .cpu() also works for GPU tensors, where np.stack would fail.
    return hist_t, pt.stack(hist_x).cpu().numpy()
def solve_euler(t_span, x, h, f=None):
    """Integrate dx/dt = f(t, x) with the explicit (forward) Euler method.

    Args:
        t_span: ``(t_start, t_end)`` pair giving the integration interval.
        x: Initial state tensor.
        h: Step size; must be positive.
        f: Right-hand side ``f(t, x)``. Defaults to the module-level ``func_xd``.

    Returns:
        ``(hist_t, hist_x)``:
        - ``hist_t`` is a list of sample times. It starts at ``t_start`` and
          holds one entry per completed step. It ends at ``t_end`` when the
          interval is a whole number of steps.
        - ``hist_x`` is a NumPy array whose i-th row is the flattened state at
          ``hist_t[i]``.

    Raises:
        ValueError: If ``h`` is not positive.
    """
    if h <= 0:
        raise ValueError(f"step size h must be positive, got {h}")
    rhs = func_xd if f is None else f
    t_start, t_end = t_span
    # Precompute the step count (tolerating float round-off) and derive each
    # time from the step index, so time does not drift.
    n_steps = max(0, int((t_end - t_start) / h + 1e-9))

    hist_t = [t_start]
    # Detach + clone so the history neither tracks gradients nor aliases x.
    hist_x = [x.detach().flatten().clone()]
    t = t_start
    for i in range(1, n_steps + 1):
        # First-order update: x_{n+1} = x_n + h * f(t_n, x_n).
        x = x + h * rhs(t, x)
        t = t_start + i * h
        hist_t.append(t)
        hist_x.append(x.detach().flatten().clone())
    # torch.stack + .cpu() also works for GPU tensors, where np.stack would fail.
    return hist_t, pt.stack(hist_x).cpu().numpy()
## Simulation setup: spring-mass-damper system  m*x'' + c*x' + k*x = 0
dt = 0.01  # time step
t_span = (0.0, 10.0)  # integration interval (start, end); was used but never defined
m, c, k = 1.0, 1.0, 1.0  # mass, damping coefficient, spring constant (spring-mass-damper system)
x_initial = [[0.1], [0.0]]  # initial state [x, xd] as a column vector
# State-space form: d/dt [x, xd]^T = A @ [x, xd]^T, i.e. xdd = -(k/m)*x - (c/m)*xd
A = pt.tensor([[0, 1], [-k / m, -c / m]], dtype=pt.float)
B = pt.tensor([[0], [1]], dtype=pt.float)  # input matrix; unused in this free-response run
x_initial = pt.tensor(x_initial, dtype=pt.float)
times, result = solve_rk4(t_span, x_initial, dt)
plt.grid()
plt.plot(times, result[:, 0], lw=3, label='$x$')
# Raw string: '\d' is not a valid escape; the label text is unchanged.
plt.plot(times, result[:, 1], lw=3, label=r'$\dot{x}$')
plt.legend()
plt.show()